import torch
import torch.nn.functional as F
from torch import nn, optim
import logging
from torchmetrics import Accuracy
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from src.pl_model.distillation import Distilltion
import wandb
import os
from omegaconf import OmegaConf
import copy
from src.utils.load_models import load_models, load_model3, load_model_qktD
from hydra.utils import instantiate


# Module-level logger; used below for alpha-construction diagnostics (per-class/teacher info logs).
log = logging.getLogger(__name__)


class MyDistillationMutlipleTeachers_QKT_fc(Distilltion):
    def __init__(self, cfg, l_model: nn.Module = None):
        """
        Set up multi-teacher (QKT) distillation into a single learner model.

        Parameters:
        cfg: experiment config, also stored as Lightning hyper-parameters. Keys read here
            include ``learner_client``, ``teacher_client``, ``clients``, ``model``,
            ``num_classes``, ``train_exp_id``, ``noise_threshold``, ``KL_temperature``,
            ``qkt_unweighted_teachers``, ``goal_class``, ``alpha_data_free`` and
            ``data_free_option``; the parameter-masking keys are optional (read with
            ``hparams.get`` defaults).
        l_model (nn.Module, optional): when truthy, a deep copy of it becomes the learner
            instead of a deep copy of the learner client's own loaded model.
        """
        super().__init__(cfg=cfg)
        self.save_hyperparameters(cfg)
        self.cfg = cfg  # Save cfg to self
        # Manual optimisation: training_step calls manual_backward itself.
        self.automatic_optimization = False
        self.learner_client = cfg.learner_client
        print(f"self.learner_client: {self.learner_client} ")
        self.teacher_client = cfg.teacher_client
        self.data_dists_vect = self.get_data_dists_vect(cfg)

        # presumably load_models returns one model per client, indexable by client id
        # (the next line and get_alphas_data_free3 rely on that) — TODO confirm.
        self.teacher_model = load_models(clients_ids=cfg.teacher_client, clients=cfg.clients,
                                         model=instantiate(cfg.model))
        # The learner starts as an independent copy of its own (teacher-side) model.
        self.learner_model = copy.deepcopy(self.teacher_model[cfg.learner_client])

        # NOTE(review): truthiness test on an nn.Module; `is not None` would be more explicit.
        if l_model:
            print(f"starting from the provided student model...")
            self.learner_model = copy.deepcopy(l_model)
        self.num_classes = cfg.num_classes
        # W&B run path used to fetch per-class accuracies of the training run.
        self.run_id = cfg.train_exp_id
        # NOTE(review): re-reads the same values from hparams that were already set from cfg above.
        self.learner_client = self.hparams.learner_client
        self.teacher_client = self.hparams.teacher_client
        self.noise_threshold = cfg.noise_threshold
        # Temperature for the teacher/learner KL terms in training_step.
        self.T = self.hparams.KL_temperature
        self.qkt_unweighted_teachers = self.hparams.qkt_unweighted_teachers
        self.goal_class = self.hparams.goal_class
        print(f"self.goal_classes: {self.goal_class}")
        # Replaced by a per-class tensor when get_alphas_data_free runs (data_free_option == 1).
        self.mask = []

        # Use get_alphas_data_free if cfg.alpha_data_free is True
        # (data_free_option 1 -> get_alphas_data_free, any other value -> get_alphas_data_free3).
        if self.hparams.alpha_data_free:
            if self.hparams.data_free_option == 1:
                self.alphas = torch.stack(self.get_alphas_data_free(self.teacher_model))
                print(f"using the data_free option1")
            else:
                self.alphas = torch.stack(self.get_alphas_data_free3(self.teacher_model))

            # Uniform learner weights; training_step does not apply learner_alpha in data-free mode.
            self.learner_alpha = [1 for _ in range(self.num_classes)]
            # if cfg.personalized_qkt: #TODO: try this
            #     self.learner_alpha = [2 for _ in range(self.num_classes)]
        else:
            self.alphas = torch.stack(self.get_alphas(self.hparams.teacher_client))
            # Learner weight per class = 1 - total teacher weight (computed before adjust_alphas).
            self.learner_alpha = 1 - torch.sum(self.alphas, axis=0)
            self.alphas = self.adjust_alphas(self.alphas)  # for the QKT

        # Alias of the stacked alphas (assigned again further below).
        self.alpha = self.alphas


        # the following params in this section are for the param masking exp that aims to mitigate forgeting,
        # they are not used in our main experiments and results

        self.soft_mask_value = self.hparams.get('soft_mask_value', 0.1)
        self.top_Z_percent = self.hparams.get('top_Z_percent', 0.3)
        self.hard_mask = self.hparams.get('hard_mask', False)
        self.use_l2_norm = self.hparams.get('use_l2_norm', True)
        self.grad_norm_type = self.hparams.get('grad_norm_type', 'grad')  # grad or grad_param_ratio

        # ######

        # Best-so-far trackers; presumably updated by validation hooks outside this view.
        self.best_query_class_acc_gain_epoch = 0
        self.least_forgetting_epoch = 0
        self.best_val_acc_epoch = 0
        self.best_simple_weighted_accuracy_epoch = 0
        self.best_uniform_accuracy_epoch = 0
        self.best_val_acc = float('-inf')
        self.best_query_class_acc_gain = float('-inf')
        self.least_forgetting = float('-inf')
        self.best_simple_weighted_accuracy = float('-inf')
        self.best_uniform_accuracy = float('-inf')

        # NOTE(review): duplicate of the alias assignment above.
        self.alpha = self.alphas
        if not self.qkt_unweighted_teachers:
            log.info(f"Alphas are:\n{self.alphas}")
            log.info(f"Learner Alpha is:\n{self.learner_alpha}")
        else:
            log.info(f"Using qkt_unweighted_teachers!")

        if self.hard_mask:
            print("Hard masking!")
        # Freeze all layers except classification head
        if self.hparams.get('freeze_backbone', False):
            self.freeze_backbone()

    def freeze_backbone(self):
        """
        Leave only the classification head of ``self.learner_model`` trainable.

        A parameter belongs to the head when its name contains ``"fc"`` (ResNet-style)
        or ``"classifier"`` (FedNet-style); all other parameters get
        ``requires_grad = False`` and head parameters get ``requires_grad = True``.
        """
        head_markers = ("fc", "classifier")
        for param_name, tensor in self.learner_model.named_parameters():
            tensor.requires_grad = any(marker in param_name for marker in head_markers)
        print("Backbone layers are frozen, only classification head is trainable.")

    def get_alphas(self, teachers):
        """
        Per-class teacher weights proportional to validation accuracy.

        For teacher ``t`` and class ``c`` the weight is
        ``acc_t[c] / (sum_k acc_k[c] + acc_learner[c])``; NaN (0/0) and infinite
        entries are made finite with ``torch.nan_to_num``.

        Parameters:
        teachers (iterable): client ids of the teachers.

        Returns:
        list of tensors, one per teacher.
        """
        teacher_accs = []
        for client_id in teachers:
            teacher_accs.append(self.get_per_class_accuracy(client_id))
        learner_acc = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)
        denominator = sum(teacher_accs) + learner_acc
        return [torch.nan_to_num(acc / denominator) for acc in teacher_accs]

    def get_alphas_data_free(self, teacher_models):
        """
        Data-free alphas built from a learner/goal relevance mask.

        A class is relevant (mask value 1) when the learner's validation accuracy for it is
        non-zero or it is a goal class; the mask is stored on ``self.mask``. Each teacher's
        alpha copies the mask on the classes where its binary weight (from
        ``identify_binary_teacher_weights``) equals 1 and is 0 elsewhere. Afterwards every
        non-zero goal-class entry is multiplied by ``hparams.goal_class_boost``.

        Parameters:
        teacher_models (list): teacher models passed to ``identify_binary_teacher_weights``.

        Returns:
        list of 1-D tensors of length ``num_classes``, one per teacher.
        """
        binary_teacher_weights = self.identify_binary_teacher_weights(teacher_models)
        learner_acc = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        relevance = torch.zeros(self.num_classes)
        for cls in range(self.num_classes):
            if learner_acc[cls] != 0 or cls in self.goal_class:
                relevance[cls] = 1
        self.mask = relevance
        print(f"self.mask: {self.mask}")

        alphas = []
        for t_idx, t_weights in enumerate(binary_teacher_weights):
            teacher_alpha = torch.zeros(self.num_classes)
            for cls in range(self.num_classes):
                log.info(f"Teacher {t_idx}, Class {cls}, Binary Weight: {t_weights[cls]}, Mask: {relevance[cls]}")
                if t_weights[cls] != 1:
                    continue
                teacher_alpha[cls] = relevance[cls]
                log.info(f"Setting alpha[{cls}] for teacher {t_idx} to {relevance[cls]}")
            print(f"alpha: {teacher_alpha}")
            alphas.append(teacher_alpha)

        # Second pass: emphasise goal classes in every teacher that contributes them.
        for teacher_alpha in alphas:
            for g in self.goal_class:
                if teacher_alpha[g] != 0:
                    teacher_alpha[g] *= self.hparams.goal_class_boost
        return alphas

    def get_alphas_data_free2(self, teachers):
        """
        Data-free alphas that keep only teachers covering at least one goal class.

        A class is relevant when the learner's validation accuracy for it is non-zero or
        it is a goal class. Binary teacher weights are computed with
        ``threshold=self.noise_threshold``. A teacher whose binary weight is 1 for some
        goal class gets the relevance mask on the classes it knows (binary weight 1), with
        non-zero goal entries multiplied by ``hparams.goal_class_boost``; every other
        teacher gets an all-zero alpha.

        Parameters:
        teachers (list): teacher models.

        Returns:
        list of 1-D tensors of length ``num_classes``, one per teacher.
        """
        binary_teacher_weights = self.identify_binary_teacher_weights(teachers, threshold=self.noise_threshold)
        learner_acc = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        relevance = torch.zeros(self.num_classes)
        for cls in range(self.num_classes):
            if learner_acc[cls] != 0 or cls in self.goal_class:
                relevance[cls] = 1
        print(f"mask: {relevance}")

        alphas = []
        for t_idx, t_weights in enumerate(binary_teacher_weights):
            teacher_alpha = torch.zeros(self.num_classes)
            covers_goal = any(t_weights[g] == 1 for g in self.goal_class)
            if covers_goal:
                for cls in range(self.num_classes):
                    log.info(f"Teacher {t_idx}, Class {cls}, Binary Weight: {t_weights[cls]}, Mask: {relevance[cls]}")
                    if t_weights[cls] != 1:
                        continue
                    teacher_alpha[cls] = relevance[cls]
                    log.info(f"Setting alpha[{cls}] for teacher {t_idx} to {relevance[cls]}")
                # Boost the goal classes this teacher actually contributes.
                for g in self.goal_class:
                    if teacher_alpha[g] != 0:
                        teacher_alpha[g] *= self.hparams.goal_class_boost
            print(f"alpha: {teacher_alpha}")
            alphas.append(teacher_alpha)
        return alphas

    def get_alphas_data_free3(self, teacher_models):
        """
        Data-free alphas from a shared goal/learner mask, preferring goal-class teachers.

        Learner classes are those with a non-zero entry in either the learner's per-class
        sample counts (``cfg.use_number_of_samples``) or its per-class validation accuracy.
        Two masks are built:
        * the shared mask: ``hparams.goal_class_boost`` on goal classes and, unless
          ``cfg.only_goal_classes`` is set, 1 on learner classes (the learner value is
          written second, so a class that is both ends up at 1);
        * the learner mask: 1 on learner classes.

        If any teacher's binary weight is 1 for a goal class, teachers covering a goal
        class and the teacher at index ``self.learner_client`` get the shared mask; the
        learner's entry gets the learner mask instead when both ``cfg.only_goal_classes``
        and ``hparams.copy_of_self_as_teacher`` are set; all other teachers get zeros.
        Otherwise every teacher gets the shared mask and the learner's entry is replaced
        by the learner mask.

        Parameters:
        teacher_models (list): teacher models passed to ``identify_binary_teacher_weights``.

        Returns:
        list of 1-D tensors of length ``num_classes``, one per teacher.
        """
        binary_teacher_weights = self.identify_binary_teacher_weights(teacher_models)

        if self.cfg.use_number_of_samples:
            learner_presence = torch.tensor(self.data_dists_vect[f'client_{self.learner_client}'])
        else:
            learner_presence = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        shared_mask = torch.zeros(self.num_classes)
        learner_client_mask = torch.zeros(self.num_classes)
        only_goal = self.cfg.only_goal_classes
        for cls in range(self.num_classes):
            if cls in self.goal_class:
                shared_mask[cls] = self.hparams.goal_class_boost
            if learner_presence[cls] != 0:
                learner_client_mask[cls] = 1
                if not only_goal:
                    shared_mask[cls] = 1

        def _covers_goal(weights):
            # True when this teacher's binary weight is 1 for at least one goal class.
            return any(weights[g] == 1 for g in self.goal_class)

        goal_class_detected = any(_covers_goal(w) for w in binary_teacher_weights)

        if not goal_class_detected:
            log.info("No teacher with goal class found. Adding all teachers.")
            alphas = [shared_mask.clone() for _ in binary_teacher_weights]
            alphas[self.learner_client] = learner_client_mask.clone()
        else:
            alphas = []
            for t_idx, t_weights in enumerate(binary_teacher_weights):
                if not (_covers_goal(t_weights) or t_idx == self.learner_client):
                    alphas.append(torch.zeros(self.num_classes))
                    continue
                use_learner_mask = (self.cfg.only_goal_classes and self.hparams.copy_of_self_as_teacher
                                    and t_idx == self.learner_client)
                alphas.append(learner_client_mask.clone() if use_learner_mask else shared_mask.clone())

        used_teachers = [idx for idx, _ in enumerate(binary_teacher_weights) if any(alphas[idx] != 0)]
        log.info(f"Used teachers: {used_teachers}")
        return alphas


    def get_alphas_data_free3_b(self, teacher_models):
        """
        Variant entry point with the same logic as ``get_alphas_data_free3``.

        Learner classes are those with a non-zero entry in the learner's per-class sample
        counts (``cfg.use_number_of_samples``) or its per-class validation accuracy. The
        shared mask holds ``hparams.goal_class_boost`` on goal classes and, unless
        ``cfg.only_goal_classes``, 1 on learner classes (learner value written last); the
        learner mask holds 1 on learner classes.

        When some teacher covers a goal class (binary weight 1), goal-covering teachers and
        the teacher at index ``self.learner_client`` get the shared mask, the learner's
        entry switching to the learner mask when both ``cfg.only_goal_classes`` and
        ``hparams.copy_of_self_as_teacher`` are set; the rest get zeros. Otherwise all
        teachers get the shared mask and the learner's entry gets the learner mask.

        Parameters:
        teacher_models (list): teacher models passed to ``identify_binary_teacher_weights``.

        Returns:
        list of 1-D tensors of length ``num_classes``, one per teacher.
        """
        binary_teacher_weights = self.identify_binary_teacher_weights(teacher_models)

        if self.cfg.use_number_of_samples:
            learner_presence = torch.tensor(self.data_dists_vect[f'client_{self.learner_client}'])
        else:
            learner_presence = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        shared_mask = torch.zeros(self.num_classes)
        learner_client_mask = torch.zeros(self.num_classes)
        only_goal = self.cfg.only_goal_classes
        for cls in range(self.num_classes):
            if cls in self.goal_class:
                shared_mask[cls] = self.hparams.goal_class_boost
            if learner_presence[cls] == 0:
                continue
            learner_client_mask[cls] = 1
            if not only_goal:
                shared_mask[cls] = 1

        teacher_flags = None
        goal_class_detected = any(
            any(weights[g] == 1 for g in self.goal_class) for weights in binary_teacher_weights
        )

        if goal_class_detected:
            alphas = []
            for t_idx, t_weights in enumerate(binary_teacher_weights):
                teacher_flags = any(t_weights[g] == 1 for g in self.goal_class)
                is_learner = t_idx == self.learner_client
                if teacher_flags or is_learner:
                    if self.cfg.only_goal_classes and self.hparams.copy_of_self_as_teacher and is_learner:
                        alphas.append(learner_client_mask.clone())
                    else:
                        alphas.append(shared_mask.clone())
                else:
                    alphas.append(torch.zeros(self.num_classes))
        else:
            log.info("No teacher with goal class found. Adding all teachers.")
            alphas = [shared_mask.clone() for _ in binary_teacher_weights]
            alphas[self.learner_client] = learner_client_mask.clone()

        used_teachers = [idx for idx, _ in enumerate(binary_teacher_weights) if any(alphas[idx] != 0)]
        log.info(f"Used teachers: {used_teachers}")
        return alphas

    def get_alphas_data_free4(self, teachers):
        """
        Give every teacher the same relevance mask as its alpha.

        A class is relevant (value 1) when the learner's validation accuracy for it is
        non-zero or it is a goal class; non-zero goal entries are then multiplied by
        ``hparams.goal_class_boost``. The binary teacher weights are computed but do not
        influence the result.

        Parameters:
        teachers (list): teacher models.

        Returns:
        list of 1-D tensors (independent clones of the mask), one per teacher.
        """
        print(f">> using (get_alphas_data_free4)")
        binary_teacher_weights = self.identify_binary_teacher_weights(teachers)
        learner_acc = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        shared_alpha = torch.zeros(self.num_classes)
        for cls in range(self.num_classes):
            if learner_acc[cls] != 0 or cls in self.goal_class:
                shared_alpha[cls] = 1
        print(f"mask: {shared_alpha}")

        for g in self.goal_class:
            if shared_alpha[g] != 0:
                shared_alpha[g] *= self.hparams.goal_class_boost

        alphas = []
        for t_idx, _ in enumerate(binary_teacher_weights):
            teacher_alpha = shared_alpha.clone()
            print(f"alpha for teacher {t_idx}: {teacher_alpha}")
            alphas.append(teacher_alpha)
        return alphas

    def adjust_alphas(self, alphas):
        """
        Keep only the goal-class and shared learner-class entries of the teacher alphas.

        Parameters:
        alphas (torch.Tensor): stacked per-teacher weights, one row per teacher.

        Returns:
        torch.Tensor of the same shape where, per teacher row:
        * goal classes keep their alpha, raised by 0.2 when it is non-zero and <= 0.8;
        * classes where both ``self.learner_alpha`` and the teacher alpha are non-zero
          keep the raw teacher alpha (written after the goal pass, so it overrides the
          +0.2 on goal classes that also satisfy this);
        * every other entry is 0.
        """
        adjusted = torch.zeros_like(alphas)
        for row, teacher_alpha in enumerate(alphas):
            for g in self.goal_class:
                value = teacher_alpha[g].clone().detach()
                if value != 0 and value <= 0.8:
                    value = value + 0.2
                adjusted[row][g] = value
            for cls in range(self.num_classes):
                if self.learner_alpha[cls] != 0 and teacher_alpha[cls] != 0:
                    adjusted[row][cls] = teacher_alpha[cls].clone().detach()
        return adjusted

    def on_fit_start(self):
        """
        Lightning hook: move weights and teachers to the device and precompute masking state.

        Converts ``learner_alpha`` and each teacher alpha to float tensors on
        ``self.device``, moves the teachers to the device in eval mode, computes the
        classifier Fisher / importance scores and the derived masks, snapshots the
        learner parameters, and optionally measures pre-transfer per-class accuracy.
        """
        def _as_float_tensor(value):
            # torch.tensor(<Tensor>) emits a copy-construct UserWarning; clone explicitly
            # instead (same values, still an independent copy).
            if isinstance(value, torch.Tensor):
                return value.detach().clone().to(dtype=torch.float, device=self.device)
            return torch.tensor(value, dtype=torch.float, device=self.device)

        self.learner_alpha = _as_float_tensor(self.learner_alpha)
        self.alphas = [_as_float_tensor(a) for a in self.alphas]
        # Teachers are kept in a plain Python list, so they are not registered submodules:
        # Lightning neither moves them nor toggles their train/eval mode. Put them in eval
        # mode explicitly so BatchNorm/Dropout do not run in training mode while they only
        # provide distillation targets.
        self.teacher_model = [t.to(self.device).eval() for t in self.teacher_model]
        # Compute Fisher Information Matrix and L2 norms at the start of training
        self.fisher_matrix = self.compute_fisher_information_fc_individual()
        self.importance_scores = self.compute_importance_fc()
        self.prev_param = {name: param.clone().detach() for name, param in self.learner_model.named_parameters()}
        # Store initial parameter values
        self.initial_params = {name: param.clone().detach() for name, param in self.learner_model.named_parameters()}
        # Compute masks before training
        print(f"computing mask_l2_norm...")
        self.masks_l2 = self.compute_mask_l2_norm_gradients()
        print(f"computing mask_fisher_information...")
        self.masks_fisher = self.compute_mask_fisher_information()

        if self.cfg.measure_pre_transfer_acc:
            self.pre_transfer_acc = self.calculate_per_class_accuracy(self.learner_model)


    def calculate_per_class_accuracy(self, model):
        """
        Per-class accuracy of ``model`` on the datamodule's test loader.

        The passed-in ``model`` is switched to eval mode for the evaluation (under
        ``torch.no_grad``) and its previous train/eval mode is restored afterwards.
        Labels with a trailing singleton dimension (medMNIST) are squeezed to 1-D.

        Parameters:
        model (nn.Module): classifier; assumed to return logits of shape
            (batch, num_classes) — TODO confirm for all model types.

        Returns:
        numpy.ndarray of length ``self.num_classes``; classes absent from the test set
        yield NaN (0 / 0).
        """
        correct = torch.zeros(self.num_classes, device=self.device)
        total = torch.zeros(self.num_classes, device=self.device)

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for x, y in self.trainer.datamodule.test_dataloader():
                    x, y = x.to(self.device), y.to(self.device)
                    if len(y.shape) > 1:
                        y = y.squeeze(1)  # medMNIST labels come as (batch, 1)
                    outputs = model(x)
                    _, predicted = torch.max(outputs, 1)
                    # Count hits and samples per label without a per-sample Python loop
                    # (also avoids indexing a 0-dim tensor when the batch has one sample).
                    hit_labels = y[predicted == y]
                    correct += torch.bincount(hit_labels, minlength=self.num_classes).to(correct.dtype)
                    total += torch.bincount(y, minlength=self.num_classes).to(total.dtype)
        finally:
            model.train(was_training)

        per_class_accuracy = correct / total
        return per_class_accuracy.cpu().numpy()

    def calculate_client_all_accuracies(self, per_class_acc):
        """
        Summarise the learner's per-class accuracy after knowledge transfer.

        Uses ``self.goal_class`` as the query classes and the learner's class-count vector
        ``self.data_dists_vect['client_<learner_client>']``. Pre-transfer per-class
        accuracy is ``self.pre_transfer_acc`` when ``cfg.measure_pre_transfer_acc`` is set,
        otherwise the ``client-<id>/per_class_test_acc`` entry of the W&B summary of run
        ``self.run_id``.

        Parameters:
        - per_class_acc (sequence of float): learner accuracy per class after training.

        Returns:
        - tuple ``(uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy,
          query_classes_acc_gain, local_classes_accuracy, forgetting)``:
          * uniform_accuracy: mean accuracy over classes that are local (count > 0) or query;
          * simple_weighted_accuracy: weighted mean over the same classes, query classes
            weighted 1 and local classes by their share of the learner's samples;
          * query_classes_accuracy / query_classes_acc_gain: mean accuracy and mean gain
            over pre-transfer accuracy on the query classes;
          * local_classes_accuracy: sample-weighted mean over local, non-query classes;
          * forgetting: sum of negative per-class accuracy changes divided by the number
            of classes with positive pre-transfer accuracy (0 when there are none).
        """
        learner_client = self.learner_client
        class_distribution = self.data_dists_vect[f'client_{learner_client}']
        query_classes = self.goal_class

        # Initialize variables for each strategy
        total_uniform_acc = total_weighted_acc = total_local_class_acc = 0
        count_uniform_classes = total_weight = total_local_weight = 0

        num_classes = len(per_class_acc)
        print(f"num_classes: {num_classes}")
        query_class_acc = [per_class_acc[i] for i in query_classes]

        if self.cfg.measure_pre_transfer_acc:
            pre_transfer_acc = self.pre_transfer_acc
        else:
            api = wandb.Api()
            train_run = api.run(self.run_id)
            train_run_summary = train_run.summary._json_dict
            pre_transfer_acc = train_run_summary[f'client-{learner_client}/per_class_test_acc']

        query_class_acc_gain = [(per_class_acc[i] - pre_transfer_acc[i]) for i in query_classes]

        # Loop-invariant. Only used as a divisor for local (count > 0) classes, so it is
        # non-zero whenever it is used (assuming non-negative counts).
        total_samples = sum(class_distribution)

        for cls_index in range(num_classes):
            is_query = cls_index in query_classes
            if not (class_distribution[cls_index] > 0 or is_query):
                continue

            # Uniform accuracy
            total_uniform_acc += per_class_acc[cls_index]
            count_uniform_classes += 1

            # Simple weighted accuracy (query classes get weight 1)
            weight = 1 if is_query else class_distribution[cls_index] / total_samples
            total_weighted_acc += weight * per_class_acc[cls_index]
            total_weight += weight

            # Local classes accuracy excluding query classes
            if not is_query:
                local_weight = class_distribution[cls_index] / total_samples
                total_local_class_acc += local_weight * per_class_acc[cls_index]
                total_local_weight += local_weight

        # Compute accuracies for each strategy
        uniform_accuracy = total_uniform_acc / count_uniform_classes if count_uniform_classes > 0 else 0
        simple_weighted_accuracy = total_weighted_acc / total_weight if total_weight > 0 else 0
        query_classes_accuracy = sum(query_class_acc) / len(query_classes) if query_classes else 0
        query_classes_acc_gain = sum(query_class_acc_gain) / len(query_class_acc_gain) if query_classes else 0
        local_classes_accuracy = total_local_class_acc / total_local_weight if total_local_weight > 0 else 0

        accuracy_drops = sum(
            delta for delta in (per_class_acc[j] - pre_transfer_acc[j] for j in range(num_classes)) if delta < 0
        )
        num_known_before = len([accuracy for accuracy in pre_transfer_acc if accuracy > 0])
        # Guard against ZeroDivisionError when no class was known before transfer.
        forgetting = accuracy_drops / num_known_before if num_known_before > 0 else 0

        print(f"uniform_accuracy: {uniform_accuracy}")
        print(f"simple_weighted_accuracy: {simple_weighted_accuracy}")
        print(f"query_classes_accuracy: {query_classes_accuracy}")
        print(f"query_classes_acc_gain: {query_classes_acc_gain}")
        print(f"local_classes_accuracy: {local_classes_accuracy}")
        print(f"forgetting: {forgetting}")

        return (
            uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy, query_classes_acc_gain,
            local_classes_accuracy,
            forgetting)

    def get_per_class_accuracy(self, client_id):
        """
        Read a client's per-class validation accuracy from the W&B run summary.

        Looks up ``client-<client_id>/val_per_class_acc`` in the summary of run
        ``self.run_id`` (a new ``wandb.Api`` client is created per call).

        Returns:
        torch.Tensor built from the stored per-class values.
        """
        summary_key = f"client-{client_id}/val_per_class_acc"
        wandb_run = wandb.Api().run(self.run_id)
        return torch.tensor(wandb_run.summary[summary_key])



    def compute_fisher_information_fc_individual(self):
        """
        Diagonal Fisher estimate for the classifier weights of the learner.

        For every learner parameter whose name contains ``fc.weight`` (ResNet) or
        ``classifier.weight`` (FedNet), accumulates the squared gradient of the batch
        cross-entropy loss, averaged over the batches of the datamodule's train loader
        (a batch-level diagonal Fisher approximation).

        The learner's previous train/eval mode is restored and the gradients produced
        here are cleared afterwards, so neither leaks into the training loop.

        Returns:
        dict[str, torch.Tensor]: parameter name -> per-element estimate (parameter shape).
        """
        model = self.learner_model
        head_keys = ('fc.weight', 'classifier.weight')  # Adjust for both ResNet and FedNet
        fisher_matrix = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
            if any(key in name for key in head_keys)
        }

        was_training = model.training
        model.eval()
        try:
            dataloader = self.trainer.datamodule.train_dataloader()
            num_batches = len(dataloader)
            for batch_idx, batch in enumerate(dataloader):
                x, y = batch
                x, y = x.to(self.device), y.to(self.device)
                if len(y.shape) > 1:
                    y = y.squeeze(1)  # Squeeze the labels to ensure they are 1D  (for medMNIST dataset)

                model.zero_grad()
                loss = F.cross_entropy(model(x), y)
                loss.backward()
                for name, param in model.named_parameters():
                    if name in fisher_matrix and param.grad is not None:
                        fisher_matrix[name] += param.grad ** 2 / num_batches
                if batch_idx % 10 == 0:
                    print(f"Processed batch {batch_idx + 1}/{num_batches}")
        finally:
            # Do not leave stale gradients or eval mode behind for training.
            model.zero_grad()
            model.train(was_training)

        print(f"Computed Fisher Information Matrix for fc layer: {fisher_matrix}")
        # Print statistics
        for name, importance in fisher_matrix.items():
            print(f"Statistics for {name}:")
            print(f"Mean: {importance.mean().item()}")
            print(f"Standard Deviation: {importance.std().item()}")
            print(f"Min: {importance.min().item()}")
            print(f"Max: {importance.max().item()}")
        return fisher_matrix

    def compute_importance_fc(self):
        """
        Per-weight importance scores for the classifier weights of the learner.

        For parameters whose name contains ``fc.weight`` or ``classifier.weight`` the
        per-batch score is ``|grad| / (|param| + 1e-10)`` when
        ``self.grad_norm_type == 'grad_param_ratio'`` and ``grad ** 2`` otherwise; scores
        are averaged over the batches of the datamodule's train loader.

        The learner's previous train/eval mode is restored and its gradients are cleared
        afterwards, so neither leaks into the training loop.

        Returns:
        dict[str, torch.Tensor]: parameter name -> score tensor (parameter shape).
        """
        model = self.learner_model
        head_keys = ('fc.weight', 'classifier.weight')  # Adjust for both ResNet and FedNet
        epsilon = 1e-10  # Small constant to prevent division by zero
        use_ratio = self.grad_norm_type == 'grad_param_ratio'
        importance_scores = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
            if any(key in name for key in head_keys)
        }

        updated = False
        was_training = model.training
        model.eval()
        try:
            dataloader = self.trainer.datamodule.train_dataloader()
            num_batches = len(dataloader)
            for batch_idx, batch in enumerate(dataloader):
                x, y = batch
                x, y = x.to(self.device), y.to(self.device)
                if len(y.shape) > 1:
                    y = y.squeeze(1)  # Squeeze the labels to ensure they are 1D  (for medMNIST dataset)

                model.zero_grad()
                loss = F.cross_entropy(model(x), y)
                loss.backward()
                for name, param in model.named_parameters():
                    if name not in importance_scores or param.grad is None:
                        continue
                    grad = param.grad.detach()
                    if use_ratio:
                        score = grad.abs() / (param.detach().abs() + epsilon)
                    else:
                        score = grad ** 2
                    importance_scores[name] += score / num_batches
                    updated = True
                if batch_idx % 10 == 0:
                    print(f"Processed batch {batch_idx + 1}/{num_batches}")
        finally:
            # Do not leave stale gradients or eval mode behind for training.
            model.zero_grad()
            model.train(was_training)

        # Reported once instead of once per parameter per batch.
        if updated:
            if use_ratio:
                print(f"Calculated importance score based on grad_param_ratio")
            else:
                print(f"Calculated importance score based on grad only")
        print(f"Computed L2 Importance Scores for fc layer: {importance_scores}")
        # Print statistics
        for name, importance in importance_scores.items():
            print(f"L2 Statistics for {name}:")
            print(f"Mean: {importance.mean().item()}")
            print(f"Standard Deviation: {importance.std().item()}")
            print(f"Min: {importance.min().item()}")
            print(f"Max: {importance.max().item()}")
        return importance_scores

    def compute_mask_l2_norm_gradients(self):
        """
        Gradient multipliers for the classifier weights derived from ``self.importance_scores``.

        For each ``fc.weight`` / ``classifier.weight`` parameter, entries whose score is at
        or above the k-th largest score (k = max(int(numel * top_Z_percent), 1)) are
        "important"; with ``top_Z_percent >= 1`` the threshold is ``-inf``. Hard masking
        yields 0 for important entries and 1 elsewhere (``score < threshold``); soft
        masking yields ``soft_mask_value`` for important entries and 1 elsewhere.

        Returns:
        dict[str, torch.Tensor]: parameter name -> float mask (parameter shape).
        """
        masks = {}
        head_keys = ('fc.weight', 'classifier.weight')
        for param_name, _ in self.learner_model.named_parameters():
            if not any(key in param_name for key in head_keys):
                continue
            scores = self.importance_scores[param_name]
            flat_scores = scores.flatten()
            print(f"L2 norm scores for {param_name}: {flat_scores}")
            print(f"top_Z_percent: {self.top_Z_percent}")
            num_top = max(int(len(flat_scores) * self.top_Z_percent), 1)
            top_values, _ = torch.topk(flat_scores, num_top)
            if self.top_Z_percent < 1.0:
                threshold = top_values[-1]
            else:
                threshold = float('-inf')
            print(f"The threshold (kth_value) for {param_name} based on L2 norm: {threshold}")
            important = scores >= threshold
            if self.hard_mask:
                masks[param_name] = (scores < threshold).float()
            else:
                important_f = important.float()
                masks[param_name] = important_f * self.soft_mask_value + (1 - important_f)
            print(f"Soft mask for {param_name} based on L2 norm: {masks[param_name]}")
            print(f"Number of weights masked in {param_name} based on L2 norm: {torch.sum(important)}")
        return masks

    def compute_mask_fisher_information(self):
        """
        Gradient multipliers for the classifier weights derived from ``self.fisher_matrix``.

        Entries at or above the k-th largest Fisher value (k = max(int(numel *
        top_Z_percent), 1); threshold ``-inf`` when ``top_Z_percent >= 1``) are important.
        Hard masking gives them 0 (others 1, via ``fisher < threshold``); soft masking gives
        them ``soft_mask_value`` (others 1).

        Returns:
        dict[str, torch.Tensor]: parameter name -> float mask (parameter shape).
        """
        masks = {}
        for param_name, _ in self.learner_model.named_parameters():
            is_head = 'fc.weight' in param_name or 'classifier.weight' in param_name
            if not is_head:
                continue
            fisher = self.fisher_matrix[param_name]
            fisher_flat = fisher.flatten()
            print(f"Importance scores for {param_name}: {fisher_flat}")
            print(f"top_Z_percent: {self.top_Z_percent}")
            k = max(int(len(fisher_flat) * self.top_Z_percent), 1)
            top_values = torch.topk(fisher_flat, k)[0]
            threshold = top_values[-1] if self.top_Z_percent < 1.0 else float('-inf')
            print(f"The threshold (kth_value) for {param_name}: {threshold}")
            important = fisher >= threshold
            if self.hard_mask:
                masks[param_name] = (fisher < threshold).float()
            else:
                masks[param_name] = important.float() * self.soft_mask_value + (1 - important.float())
            print(f"Soft mask for {param_name}: {masks[param_name]}")
            print(f"Number of weights masked in {param_name}: {torch.sum(important)}")
        return masks

    def training_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        if len(y.shape) > 1:
            y = y.squeeze(1)  # Squeeze the labels to ensure they are 1D  (for medMNIST dataset)

        learner_logits = self(x)
        onehot_y = F.one_hot(y, self.num_classes).to(torch.float)
        ce = F.kl_div(F.log_softmax(learner_logits, dim=1), onehot_y, reduction='none')

        if not self.qkt_unweighted_teachers and not self.hparams.alpha_data_free:
            ce = ce * self.learner_alpha
        ce = ce.sum() / ce.size()[0]
        self.log(f"{self.exp_name}/learner-ce_loss", ce, on_step=True, on_epoch=False, prog_bar=True)

        divergences = 0
        for (alpha, t_model) in zip(self.alphas, self.teacher_model):
            with torch.no_grad():
                teacher_logits = t_model(x)
            divergence = F.kl_div(F.log_softmax(learner_logits / self.T, dim=1),
                                  F.softmax(teacher_logits / self.T, dim=1), reduction='none')

            if not self.qkt_unweighted_teachers:
                divergence = divergence * alpha
            divergence = divergence.sum() / divergence.size()[0] * self.T * self.T
            divergences += divergence
        self.log(f"{self.exp_name}/learner-kl_loss", divergences, on_step=True, on_epoch=False, prog_bar=True)

        # Calculate total loss
        total_loss = ce + divergences
        self.manual_backward(total_loss)
        # Save and print gradients before masking
        grad_before = {}
        for name, param in self.learner_model.named_parameters():
            if ('fc.weight' in name or 'classifier.weight' in name) and param.grad is not None:  # Adjust
                grad_before[name] = param.grad.clone()
                print(f"Grad for {name} before masking: {grad_before[name]}")
        # Apply masks after backward
        if self.grad_norm_type == 'grad_param_ratio' or self.use_l2_norm:
            masks_fc = self.masks_l2
        else:
            masks_fc = self.masks_fisher
        for name, param in self.learner_model.named_parameters():
            if ('fc.weight' in name or 'classifier.weight' in name) and name in masks_fc and param.grad is not None:  # Adjust
                param.grad *= masks_fc[name]
                print(f"Grad for {name} after masking: {param.grad}")
        # Save weights before optimizer step
        weights_before = {}
        for name, param in self.learner_model.named_parameters():
            if 'fc.weight' in name or 'classifier.weight' in name:  # Adjust
                weights_before[name] = param.clone().detach()
        # Perform optimizer step manually
        self.trainer.optimizers[0].step()
        self.trainer.optimizers[0].zero_grad()

        # Reset weights to initial values after optimizer step if hard masking
        if self.hard_mask:
            for name, param in self.learner_model.named_parameters():
                if 'fc.weight' in name or 'classifier.weight' in name:  # Adjust
                    # Create a mask where important weights are set to 0 (to keep them unchanged)
                    important_mask = (self.masks_l2[name] == 0)
                    # Copy original weights where important, blend with updated weights where not important
                    temp_data = param.data.clone()  # Temporary storage to avoid in-place operation conflicts
                    param.data[important_mask].copy_(self.initial_params[name][important_mask])
                    param.data[~important_mask].copy_(temp_data[~important_mask])
        # Print weights after optimizer step
        for name, param in self.learner_model.named_parameters():
            if 'fc.weight' in name or 'classifier.weight' in name:  # Adjust
                weights_after = param.clone().detach()
                print(f"Weights for {name} before optimizer step: {weights_before[name]}")
                print(f"Weights for {name} after optimizer step: {weights_after}")
        # Logging and diagnostics
        preds = torch.argmax(learner_logits, dim=1)
        acc = self.train_acc(preds, y)
        self.log(f"{self.exp_name}/train_acc", acc, on_step=False, on_epoch=True, prog_bar=False)
        return {"loss": total_loss, "preds": preds, "targets": y}

    def get_data_dists_vect(self, cfg):
        """Convert each client's ``train_data_distribution`` into a dense per-class count vector.

        Args:
            cfg: config whose ``clients`` maps client id -> entry holding a
                ``train_data_distribution`` mapping (class index as string -> count).
                ``cfg.num_classes`` gives the vector length when present.

        Returns:
            dict client id -> ``np.ndarray`` of length ``num_classes``. Entry i is the
            count stored under key ``str(i)``, or 0 when that key is missing or falsy.
        """
        # This class's __init__ calls this method before it assigns self.num_classes,
        # so prefer the value from cfg and only fall back to the attribute.
        num_classes = getattr(cfg, "num_classes", None)
        if num_classes is None:
            num_classes = self.num_classes
        total_num_samples_per_class = defaultdict(int)
        data_dists_vectorized = {}
        for client, info in cfg.clients.items():
            data_dist = info["train_data_distribution"]
            data_dists_vectorized[client] = np.array(
                [data_dist.get(str(cls_idx)) or 0 for cls_idx in range(num_classes)])
            for cls_idx, count in data_dist.items():
                total_num_samples_per_class[cls_idx] += count
        total_num_samples = sum(total_num_samples_per_class.values())
        log.info(f"data_dists_vectorized: {data_dists_vectorized}")
        log.info(f"total_num_samples_per_class: {total_num_samples_per_class}")
        log.info(f"total_num_samples: {total_num_samples}")
        return data_dists_vectorized

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics, keep per-metric best checkpoints, and log everything.

        Per-class accuracy is the diagonal of the row-normalised confusion matrix.
        It is passed to ``calculate_client_all_accuracies``. For each tracked metric
        that beats its running best, three things happen: the running best is
        updated, its epoch is stored in ``<attribute>_epoch``, and the learner's
        ``state_dict`` (cloned to CPU) replaces the previous file of that name in
        ``self.trainer.default_root_dir``.

        Args:
            outputs: unused; metrics are read from the stateful metric objects.
        """
        acc = self.val_acc.compute()  # get val accuracy from current epoch
        self.val_acc_best.update(acc)
        confusion_matrix = self.val_confusion_matrix.compute()
        # Divide every row by its own sum. Without keepdim the (C,) sum would broadcast
        # along columns and divide column j by row-sum j. The diagonal happens to come
        # out right that way, but the other entries do not.
        confusion_matrix = confusion_matrix / confusion_matrix.sum(dim=1, keepdim=True)
        self.per_class_val_acc = np.diag(confusion_matrix.cpu().detach().numpy())
        self.private_val_acc = acc
        (uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy,
         query_classes_acc_gain, local_classes_accuracy,
         forgetting) = self.calculate_client_all_accuracies(self.per_class_val_acc)
        # Ensure the model is in eval mode
        self.eval()

        # (current value, attribute holding the running best, checkpoint file name).
        # The epoch of each best is stored in "<attribute>_epoch".
        tracked_metrics = (
            (acc, "best_val_acc", "best_val_acc.ckpt"),
            (query_classes_acc_gain, "best_query_class_acc_gain", "val_best_query_class_acc_gain.ckpt"),
            # forgetting values are negative; larger means less forgetting
            (forgetting, "least_forgetting", "val_least_forgetting.ckpt"),
            (simple_weighted_accuracy, "best_simple_weighted_accuracy", "val_best_simple_weighted_accuracy.ckpt"),
            (uniform_accuracy, "best_uniform_accuracy", "val_best_uniform_accuracy.ckpt"),
        )
        current_state_dict = None  # cloned lazily, only if some checkpoint must be written
        for value, best_attr, ckpt_name in tracked_metrics:
            if value > getattr(self, best_attr):
                setattr(self, best_attr, value)
                setattr(self, f"{best_attr}_epoch", self.current_epoch)
                if current_state_dict is None:
                    current_state_dict = {k: v.clone().detach().cpu()
                                          for k, v in self.learner_model.state_dict().items()}
                ckpt_path = os.path.join(self.trainer.default_root_dir, ckpt_name)
                # Replace the previous best checkpoint for this metric
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                torch.save(current_state_dict, ckpt_path)

        self.logger.experiment.summary[f"{self.exp_name}/val_acc"] = acc
        self.logger.experiment.summary[f"{self.exp_name}/val_per_class_acc"] = self.per_class_val_acc.tolist()
        self.logger.experiment.summary[f"{self.exp_name}/val_simple_weighted_accuracy"] = simple_weighted_accuracy
        self.logger.experiment.summary[f"{self.exp_name}/val_uniform_accuracy"] = uniform_accuracy
        self.logger.experiment.summary[f"{self.exp_name}/val_query_classes_acc_gain"] = query_classes_acc_gain
        self.logger.experiment.summary[f"{self.exp_name}/val_forgetting"] = forgetting
        log.info(f"{self.exp_name}/val_acc: {acc}")
        log.info(f"{self.exp_name}/val_per_class_acc: {self.per_class_val_acc.tolist()}")
        log.info(f"{self.exp_name}/val_simple_weighted_accuracy: {simple_weighted_accuracy}")
        log.info(f"{self.exp_name}/val_uniform_accuracy: {uniform_accuracy}")
        log.info(f"{self.exp_name}/val_query_classes_acc_gain: {query_classes_acc_gain}")
        log.info(f"{self.exp_name}/val_forgetting: {forgetting}")
        self.log(f"{self.exp_name}/val_simple_weighted_accuracy", simple_weighted_accuracy, on_epoch=True)
        self.log(f"{self.exp_name}/val_uniform_accuracy", uniform_accuracy, on_epoch=True)
        self.log(f"{self.exp_name}/val_query_classes_acc_gain", query_classes_acc_gain, on_epoch=True)
        self.log(f"{self.exp_name}/val_forgetting", forgetting, on_epoch=True)
        self.log(f"{self.exp_name}/best_val_acc", self.best_val_acc, on_epoch=True)
        self.log(f"{self.exp_name}/val_best_simple_weighted_accuracy", self.best_simple_weighted_accuracy,
                 on_epoch=True)
        self.log(f"{self.exp_name}/val_best_uniform_accuracy", self.best_uniform_accuracy,
                 on_epoch=True)
        self.log(f"{self.exp_name}/val_best_query_class_acc_gain", self.best_query_class_acc_gain, on_epoch=True)
        self.log(f"{self.exp_name}/val_least_forgetting", self.least_forgetting, on_epoch=True)

    def on_fit_end(self):
        """Publish the best-metric epochs and values to the run summary, then save final weights.

        The learner's ``state_dict`` (each tensor cloned to CPU) is written to
        ``latest.ckpt`` inside ``self.trainer.default_root_dir``.
        """
        run_summary = self.logger.experiment.summary
        summary_entries = [
            # epochs at which each best was reached
            ("best_val_acc_epoch", self.best_val_acc_epoch),
            ("val_best_query_class_acc_gain_epoch", self.best_query_class_acc_gain_epoch),
            ("val_least_forgetting_epoch", self.least_forgetting_epoch),
            ("val_best_simple_weighted_accuracy_epoch", self.best_simple_weighted_accuracy_epoch),
            ("val_best_uniform_accuracy_epoch", self.best_uniform_accuracy_epoch),
            # the best values themselves
            ("best_val_acc", self.best_val_acc),
            ("val_best_query_class_acc_gain", self.best_query_class_acc_gain),
            ("val_least_forgetting", self.least_forgetting),
            ("val_best_simple_weighted_accuracy", self.best_simple_weighted_accuracy),
            ("val_best_uniform_accuracy", self.best_uniform_accuracy),
        ]
        for suffix, entry in summary_entries:
            run_summary[f"{self.exp_name}/{suffix}"] = entry

        final_weights = {}
        for param_name, tensor in self.learner_model.state_dict().items():
            final_weights[param_name] = tensor.clone().detach().cpu()
        torch.save(final_weights, os.path.join(self.trainer.default_root_dir, "latest.ckpt"))

    def on_test_end(self):
        """Reset the best-metric trackers after testing so the next run starts from scratch.

        This runs at test end rather than fit end; the original author's note says
        fit end causes issues because the next client re-initialises the network weights.
        """
        self.val_acc_best.reset()
        tracker_names = (
            "best_val_acc",
            "best_query_class_acc_gain",
            "least_forgetting",
            "best_simple_weighted_accuracy",
            "best_uniform_accuracy",
        )
        for tracker in tracker_names:
            setattr(self, tracker, float('-inf'))

    def on_save_checkpoint(self, checkpoint):
        """Store the experiment config, with interpolations resolved, under ``checkpoint['cfg']``."""
        resolved_cfg = OmegaConf.to_container(self.cfg, resolve=True)
        checkpoint['cfg'] = resolved_cfg

    def on_load_checkpoint(self, checkpoint):
        """Restore ``self.cfg`` from ``checkpoint['cfg']``.

        Raises:
            KeyError: if the checkpoint carries no ``'cfg'`` entry.
        """
        if 'cfg' not in checkpoint:
            raise KeyError("The checkpoint does not contain 'cfg' key.")
        self.cfg = OmegaConf.create(checkpoint['cfg'])
        # Any other state that must be restored from the checkpoint belongs here.

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, cfg=None, **kwargs):
        """Load the module, reconciling the checkpoint's stored ``cfg`` with an explicit ``cfg``.

        * The checkpoint has no ``cfg`` but one is given: the given cfg is embedded
          in the checkpoint handed to Lightning, so ``on_load_checkpoint`` finds it.
        * The checkpoint has ``cfg`` and none is given: the stored cfg is used.
        * Both are present: the given cfg is merged over the stored one (given values win).

        Args:
            checkpoint_path: path or file-like object of the Lightning checkpoint.
            cfg: optional config that supplies or overrides the stored one.
            **kwargs: forwarded to the parent ``load_from_checkpoint``.
        """
        import io

        # NOTE(review): on newer torch versions torch.load may default to weights_only=True,
        # which rejects pickled non-tensor objects; confirm against the torch version in use.
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        if hasattr(checkpoint_path, "seek"):
            # The parent re-reads the source, so rewind file-like inputs we just consumed.
            checkpoint_path.seek(0)
        source = checkpoint_path
        if 'cfg' not in checkpoint and cfg is not None:
            checkpoint['cfg'] = OmegaConf.to_container(cfg, resolve=True)
            # The parent loads from the given source itself. Patching only this local dict
            # would be discarded, and on_load_checkpoint would raise KeyError. Hand over
            # the patched checkpoint as an in-memory file instead.
            source = io.BytesIO()
            torch.save(checkpoint, source)
            source.seek(0)
        elif 'cfg' in checkpoint and cfg is None:
            cfg = OmegaConf.create(checkpoint['cfg'])
        elif 'cfg' in checkpoint and cfg is not None:
            cfg = OmegaConf.merge(OmegaConf.create(checkpoint['cfg']), cfg)
        return super().load_from_checkpoint(source, cfg=cfg, **kwargs)



    def identify_binary_teacher_weights(self, client_models, num_classes=10):
        """Collect ``detect_stats`` 0/1 class weights for every client model, in order.

        Args:
            client_models: iterable of models to probe.
            num_classes: unused.

        Returns:
            list with one ``detect_stats`` result per model, probed with ``self.noise_threshold``.
        """
        def _probe(position, client_model):
            print(f"\nAnalyzing client {position}'s model")
            print("Detected stats:")
            return self.detect_stats(client_model, threshold=self.noise_threshold)

        return [_probe(position, client_model) for position, client_model in enumerate(client_models)]

    def detect_stats(self, model, threshold=0.01):
        """Flag classes that ``model`` rarely predicts on synthetic noise inputs.

        The model is run on ``self.create_anonymized_data_impressions(model)``,
        moved to the device of the model's first parameter. Softmax probabilities
        (over dim 1) are averaged per class across the batch. A class whose average
        is below ``threshold`` is reported as underrepresented and gets weight 0;
        every other class gets weight 1.

        Args:
            model: module producing logits with the class dimension at dim 1
                (presumably ``(batch, num_classes)`` — confirm for exotic heads).
            threshold: minimum average probability for a class to keep weight 1.

        Returns:
            list of 0/1 ints, one per output class.
        """
        synthetic_data, _synthetic_labels = self.create_anonymized_data_impressions(model)
        synthetic_data = synthetic_data.to(next(model.parameters()).device)  # Move data to the same device as the model
        # Pure inference probe: no autograd graph is needed for the noise batch.
        # NOTE(review): the model's train/eval mode is left untouched; in train mode,
        # layers such as BatchNorm would update their running stats here.
        with torch.no_grad():
            output = model(synthetic_data)
        predicted_probs = F.softmax(output, dim=1)

        avg_predicted_probs = predicted_probs.mean(dim=0).tolist()
        binary_teacher_weights = [0 if avg_prob < threshold else 1 for avg_prob in avg_predicted_probs]
        underrepresented_classes = [cls for cls, weight in enumerate(binary_teacher_weights) if weight == 0]

        print(f"model's avg_predicted_probs: {avg_predicted_probs}")
        print(f"model's underrepresented_classes: {underrepresented_classes}")
        print(f"model's binary_teacher_weights: {binary_teacher_weights}")

        return binary_teacher_weights

    def create_anonymized_data_impressions(self, model, num_classes=10, num_samples=20, input_shape=(3, 224, 224)):
        """Generate a batch of standard-normal noise inputs used to probe a model.

        Args:
            model: unused (kept for interface compatibility).
            num_classes: unused.
            num_samples: number of noise samples to draw (default 20, as before).
            input_shape: shape of a single sample. It must match what the probed model
                expects; the default ``(3, 224, 224)`` is the previously hard-coded shape.

        Returns:
            ``(data, labels)``: ``data`` has shape ``(num_samples, *input_shape)``, and
            ``labels`` is an empty ``torch.long`` tensor (no labels are generated).
        """
        # One randn call per sample keeps RNG consumption identical to drawing them in a loop.
        synthetic_data = [torch.randn(input_shape) for _ in range(num_samples)]
        synthetic_labels = []
        return torch.stack(synthetic_data), torch.tensor(synthetic_labels, dtype=torch.long)

